package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"os"
	"reflect"
	"sort"
	"strconv"
	"strings"
)

// Command-line flags of the generator.
var (
	// supported makes main call genSupported instead of genRouter.
	supported = flag.Bool("supported", false, "genRouter supported.go")
	// output is the path the gofmt-formatted result is written to.
	output = flag.String("o", "", "output file")
	// pkg is the package name written into the generated file's package clause.
	pkg = flag.String("pkg", "", "package name")
	// src is a comma-separated list of Go files to scan for *CQBot methods.
	src = flag.String("path", "", "source file")
)

// Param describes one parameter of a routed *CQBot method: the request key it
// is read from (Name), its Go type as source text (Type), and an optional Go
// expression used when the key is absent (Default; "" means none).
type Param struct {
	Name, Type, Default string
}

// Router describes one routed *CQBot method: the method name (Func), the
// action names that dispatch to it for every spec version (Path), for v11
// only (PathV11) and for v12 only (PathV12), and its parameters in
// declaration order (Params).
type Router struct {
	Func                   string
	Path, PathV11, PathV12 []string
	Params                 []Param
}

// generator accumulates generated Go source text. In main, out is an
// in-memory *bytes.Buffer whose contents are gofmt-formatted and then
// written to the -o file.
type generator struct {
	out io.Writer // destination of every WriteString/writef call
}

// Route-table selectors passed to (*generator).router. genRouter also emits
// PathV11 and PathV12 verbatim as the `spec.Version == N` guard of each
// version-specific switch. Their values must therefore match the Version
// values onebot.Spec carries (presumably 11 and 12; confirm in pkg/onebot).
const (
	PathAll = 0  // select Router.Path: routes valid for every version
	PathV11 = 11 // select Router.PathV11
	PathV12 = 12 // select Router.PathV12
)

// WriteString appends s verbatim to the generated output. Write errors are
// deliberately discarded: main backs out with a *bytes.Buffer, whose writes
// always return a nil error.
func (g *generator) WriteString(s string) {
	_, _ = io.WriteString(g.out, s)
}

// writef appends fmt.Sprintf(format, a...) to the generated output,
// discarding any write error (see WriteString).
func (g *generator) writef(format string, a ...any) {
	_, _ = fmt.Fprintf(g.out, format, a...)
}

// header writes the "Code generated ... DO NOT EDIT." banner followed by the
// package clause, using the -pkg flag as the package name.
func (g *generator) header() {
	const banner = "// Code generated by cmd/api-generator. DO NOT EDIT.\n\n"
	g.WriteString(banner)
	g.writef("package %s\n\n", *pkg)
}

// genRouter emits the import block and the (*Caller).call dispatcher.
//
// The generated call checks the v11 route table when spec.Version == 11 and
// the v12 table when spec.Version == 12. It then falls back to the generic
// (all-version) table and finally returns coolq.Failed(404, "API_NOT_FOUND", ...).
// The output is gofmt-formatted afterwards by main, so indentation here is
// not significant.
func (g *generator) genRouter(routers []Router) {
	for _, line := range []string{
		"import (\n\n",
		"\"github.com/Mrs4s/go-cqhttp/coolq\"\n",
		"\"github.com/Mrs4s/go-cqhttp/global\"\n",
		"\"github.com/Mrs4s/go-cqhttp/pkg/onebot\"\n",
		")\n\n",
	} {
		g.WriteString(line)
	}
	g.WriteString(`func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG {`)

	// Version-specific tables are consulted before the generic one.
	for _, version := range [...]int{PathV11, PathV12} {
		g.writef(`if spec.Version == %d {
		switch action {
	`, version)
		for _, r := range routers {
			g.router(r, version)
		}
		g.WriteString("}}\n")
	}

	// Routes shared by every spec version.
	g.WriteString("switch action {\n")
	for _, r := range routers {
		g.router(r, PathAll)
	}
	g.WriteString("}\n")
	g.WriteString("return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
}

// router writes one generated `case` clause that dispatches to
// c.bot.<router.Func>. The route list is chosen by pathVersion: PathV11 selects
// router.PathV11, PathV12 selects router.PathV12, and anything else selects
// router.Path. Nothing is written when the selected list is empty.
//
// Every parameter except *onebot.Spec is bound to a local pN, read from the
// generated Getter p by its request key and converted with conv. A parameter
// with a Default starts from that expression and is overwritten only when the
// key exists. *onebot.Spec parameters are passed the generated spec argument.
func (g *generator) router(router Router, pathVersion int) {
	var path []string
	switch pathVersion {
	case PathV11:
		path = router.PathV11
	case PathV12:
		path = router.PathV12
	default:
		path = router.Path
	}
	if len(path) == 0 {
		return
	}

	labels := make([]string, 0, len(path))
	for _, name := range path {
		labels = append(labels, strconv.Quote(name))
	}
	g.WriteString(`case ` + strings.Join(labels, `, `) + ":\n")

	// Emit the parameter bindings and remember what to pass for each one.
	args := make([]string, len(router.Params))
	for i, param := range router.Params {
		if param.Type == "*onebot.Spec" {
			args[i] = "spec"
			continue
		}
		args[i] = fmt.Sprintf("p%d", i)
		key := strconv.Quote(param.Name)
		if param.Default == "" {
			g.writef("p%d := %s\n", i, conv("p.Get("+key+")", param.Type))
			continue
		}
		g.writef("p%d := %s\n", i, param.Default)
		g.writef("if pt := p.Get(%s); pt.Exists() {\n", key)
		g.writef("p%d = %s\n}\n", i, conv("pt", param.Type))
	}

	g.WriteString("\t\treturn c.bot." + router.Func + "(" + strings.Join(args, ", ") + ")\n")
}

// conv returns a Go expression that converts the gjson.Result expression v
// to the Go type t. For example, conv("pt", "int32") == "int32(pt.Int())".
//
// Signed integer types go through Result.Int, unsigned ones through
// Result.Uint, and floating-point ones through Result.Float. Narrower types
// are wrapped in an explicit conversion. gjson.Result (and *onebot.Spec,
// which router never passes here) are returned unchanged. Any other type
// panics, which aborts generation with the offending type name.
func conv(v, t string) string {
	switch t {
	case "gjson.Result", "*onebot.Spec":
		return v
	case "bool":
		return v + ".Bool()"
	case "string":
		return v + ".String()"
	case "int64":
		return v + ".Int()"
	case "uint64":
		return v + ".Uint()"
	case "float64":
		return v + ".Float()"
	case "int", "int8", "int16", "int32", "rune":
		return t + "(" + v + ".Int())"
	case "uint", "uint8", "uint16", "uint32", "byte":
		return t + "(" + v + ".Uint())"
	case "float32":
		return t + "(" + v + ".Float())"
	default:
		panic("unsupported type: " + t)
	}
}

// main parses the comma-separated Go files given by -path. It collects a
// Router for every exported *CQBot method annotated with
// @route/@route11/@route12. It then writes the gofmt-formatted output of
// genRouter (or genSupported with -supported) to the -o file.
//
// Exported *CQBot methods without any route annotation are printed by name
// on stderr and skipped. Parse, format and write errors panic.
func main() {
	flag.Parse()
	var routers []Router
	fset := token.NewFileSet()
	for _, s := range strings.Split(*src, ",") {
		file, err := parser.ParseFile(fset, s, nil, parser.ParseComments)
		if err != nil {
			panic(err)
		}
		for _, decl := range file.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || !isCQBotMethod(fn) {
				continue
			}
			router := parseRouter(fn)
			if len(router.Path) == 0 && len(router.PathV11) == 0 && len(router.PathV12) == 0 {
				fmt.Fprintln(os.Stderr, fn.Name.Name)
				continue
			}
			routers = append(routers, router)
		}
	}

	// Order by first path so the generated file is deterministic.
	sort.SliceStable(routers, func(i, j int) bool {
		return routerSortKey(routers[i]) < routerSortKey(routers[j])
	})

	out := new(bytes.Buffer)
	g := &generator{out: out}
	g.header()
	if *supported {
		g.genSupported(routers)
	} else {
		g.genRouter(routers)
	}
	source, err := format.Source(out.Bytes())
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile(*output, source, 0o644); err != nil {
		panic(err)
	}
}

// isCQBotMethod reports whether fn is an exported method with a *CQBot receiver.
func isCQBotMethod(fn *ast.FuncDecl) bool {
	return fn.Name.IsExported() && fn.Recv != nil && len(fn.Recv.List) > 0 &&
		typeName(fn.Recv.List[0].Type) == "*CQBot"
}

// parseRouter builds a Router from fn's parameters and the annotations in its
// doc comment. Annotations are applied in source order, and each path list is
// sorted afterwards.
func parseRouter(fn *ast.FuncDecl) Router {
	router := Router{Func: fn.Name.Name}
	for _, field := range fn.Type.Params.List {
		typ := typeName(field.Type)
		for _, name := range field.Names {
			router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
		}
	}
	// fn.Doc is nil when the method has no doc comment; such a method simply
	// carries no annotations (dereferencing it used to panic).
	if fn.Doc != nil {
		for _, comment := range fn.Doc.List {
			annotation, args := match(comment.Text)
			applyAnnotation(&router, annotation, args)
		}
	}
	sort.Strings(router.Path)
	sort.Strings(router.PathV11)
	sort.Strings(router.PathV12)
	return router
}

// applyAnnotation applies a single @annotation(args) to router. Unknown
// annotations, including "" for ordinary comments, are ignored.
func applyAnnotation(router *Router, annotation, args string) {
	switch annotation {
	case "route":
		router.Path = appendRoutes(router.Path, args)
	case "route11":
		router.PathV11 = appendRoutes(router.PathV11, args)
	case "route12":
		router.PathV12 = appendRoutes(router.PathV12, args)
	case "default":
		for name, value := range parseMap(args, "=") {
			for i, p := range router.Params {
				if p.Name == name {
					router.Params[i].Default = convDefault(value, p.Type)
				}
			}
		}
	case "rename":
		// NOTE(review): parseMap returns a map, so chained renames in one
		// annotation (a->b, b->c) are applied in random order.
		for name, value := range parseMap(args, "->") {
			for i, p := range router.Params {
				if p.Name == name {
					router.Params[i].Name = value
				}
			}
		}
	}
}

// appendRoutes appends each non-empty, comma-separated, optionally quoted
// route in args to dst. Surrounding whitespace is ignored, so "a, b" yields
// "a" and "b".
func appendRoutes(dst []string, args string) []string {
	for _, route := range strings.Split(args, ",") {
		route = strings.TrimSpace(route)
		if route == "" {
			continue
		}
		if route = unquote(route); route != "" {
			dst = append(dst, route)
		}
	}
	return dst
}

// routerSortKey returns r's first generic path, else its first v11 path,
// else its first v12 path, else "".
func routerSortKey(r Router) string {
	switch {
	case len(r.Path) > 0:
		return r.Path[0]
	case len(r.PathV11) > 0:
		return r.PathV11[0]
	case len(r.PathV12) > 0:
		return r.PathV12[0]
	}
	return ""
}

// unquote strips one level of quoting from s. A double-quoted Go string
// literal is unescaped with strconv.Unquote, and surrounding back-ticks are
// trimmed. Everything else is returned unchanged, including "" and a
// double-quoted text that is not a single valid literal (e.g. `"a","b"`).
// That text was previously collapsed to "", losing the whole value.
func unquote(s string) string {
	if s == "" {
		return s
	}
	switch s[0] {
	case '"':
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}

// parseMap parses a comma-separated list of key<sep>value pairs, such as
// `a=1, b="x"` with sep "=". Keys and values are whitespace-trimmed and
// values are unquoted. A bare key without sep maps to "true", and entries
// with an empty key are skipped.
func parseMap(input string, sep string) map[string]string {
	out := make(map[string]string)
	for _, arg := range strings.Split(input, ",") {
		k, v, ok := strings.Cut(arg, sep)
		k = strings.TrimSpace(k)
		if k == "" {
			continue
		}
		if !ok {
			// Flag-style entry. Without this continue, the "true" was
			// immediately overwritten and unquote("") panicked.
			out[k] = "true"
			continue
		}
		out[k] = unquote(strings.TrimSpace(v))
	}
	return out
}

// match recognizes an annotation comment of the form "// @name(args)". It
// returns name and the unquoted args, or "", "" for any other comment.
func match(text string) (string, string) {
	text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
	if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
		return "", ""
	}
	// Strip exactly one leading '@' and one closing ')' so that args may
	// themselves end in ')' (e.g. "@default(x=f())").
	cmd, args, ok := strings.Cut(text[1:len(text)-1], "(")
	if !ok {
		return "", ""
	}
	return strings.TrimSpace(cmd), unquote(args)
}

// replacer rewrites abbreviations before the snake case transformation, so
// that e.g. "userID" becomes "user_id" rather than "user_i_d".
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a camelCase Go identifier to snake_case, for example
// "groupID" -> "group_id" and "autoEscape" -> "auto_escape".
//
// Each ASCII upper-case letter becomes '_' plus its lower-case form. No '_' is
// inserted at the very start or right after an existing '_'. All other bytes
// (lower-case letters, digits, '_', and the bytes of non-ASCII runes) are
// copied unchanged. The previous c^0x20 trick turned '_' into DEL (0x7F) and
// corrupted UTF-8.
func snakecase(s string) string {
	s = replacer.Replace(s)
	t := make([]byte, 0, len(s)+8)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if len(t) > 0 && t[len(t)-1] != '_' {
				t = append(t, '_')
			}
			c += 'a' - 'A'
		}
		t = append(t, c)
	}
	return string(t)
}

// convDefault turns the annotation value s for a parameter of type t into a
// Go default-value expression. It returns "" when the value equals what the
// plain conversion would yield anyway: a bool other than "true", or uint32
// "0". Types other than bool and uint32 panic.
func convDefault(s string, t string) string {
	switch t {
	case "bool":
		if s != "true" {
			return ""
		}
		return s
	case "uint32":
		if s == "0" {
			return ""
		}
		return t + "(" + s + ")"
	}
	panic("unhandled default value type:" + t)
}

// typeName renders the type expression x as Go source text, e.g.
// "*onebot.Spec" or "gjson.Result".
//
// Identifiers, qualified identifiers and pointers are rendered directly. Any
// other expression (slices, maps, variadics, func types, ...) is printed with
// go/format. main types the parameters of every exported *CQBot method,
// routed or not, so before this fallback a single such method aborted the
// generator. It panics only when x cannot be printed as an expression.
func typeName(x ast.Node) string {
	switch x := x.(type) {
	case *ast.Ident:
		return x.Name
	case *ast.SelectorExpr:
		return typeName(x.X) + "." + x.Sel.Name
	case *ast.StarExpr:
		return "*" + typeName(x.X)
	}
	if x != nil {
		var buf bytes.Buffer
		// An empty FileSet is enough: positions only affect line breaking,
		// and a lone type expression prints on one line without them.
		if err := format.Node(&buf, token.NewFileSet(), x); err == nil {
			return strings.TrimSpace(buf.String())
		}
	}
	// fmt.Sprint tolerates a nil reflect.Type (x == nil), unlike .String().
	panic("unhandled type: " + fmt.Sprint(reflect.TypeOf(x)))
}
